{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "08bf4b89015df3fa62ec3e5cebfe3c3e5a181176f7e37ebd57af46c3a62ed99e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# 批量归一化层\n",
    "\n",
    "## 全连接层做批量归一化\n",
    "\n",
    "通常，将批量归一化层置于全连接层中的仿射变换和激活函数之间。\n",
    "\n",
    "## 卷积层做批量归一化\n",
    "\n",
    "发生在卷积计算之后、应用激活函数之前。\n",
    "\n",
    "## 预测时的批量归一化\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
    "    if not is_training:\n",
    "        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    \n",
    "    else:\n",
    "        assert len(X.shape) in (2, 4)\n",
    "        if len(X.shape) == 2:\n",
    "            # 全连接层，计算特征维上的均值和方差\n",
    "            mean = X.mean(dim=0)\n",
    "            var = ((X - mean) ** 2).mean(dim=0)\n",
    "        \n",
    "        else:\n",
    "            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "        \n",
    "        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
    "        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "        moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
    "    \n",
    "    Y = gamma * X_hat + beta # 拉伸和平移\n",
    "    return Y, moving_mean, moving_var\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNorm(nn.Module):\n",
    "    def __init__(self, num_features, num_dims):\n",
    "        super(BatchNorm, self).__init__()\n",
    "        if num_dims == 2:\n",
    "            shape = (1, num_features)\n",
    "        else:\n",
    "            shape = (1, num_features, 1, 1)\n",
    "\n",
    "        # 参与求梯度和迭代的拉伸和偏移参数，分别初始化成0和1\n",
    "        self.gamma = nn.Parameter(torch.ones(shape))\n",
    "        self.beta = nn.Parameter(torch.zeros(shape))\n",
    "        # 不参与求梯度和迭代的变量，全在内存上初始化成0\n",
    "        self.moving_mean = torch.zeros(shape)\n",
    "        self.moving_var = torch.zeros(shape)\n",
    "\n",
    "    def forward(self, X):\n",
    "        if self.moving_mean.device != X.device:\n",
    "            self.moving_mean = self.moving_mean.to(X.device)\n",
    "            self.moving_var = self.moving_var.to(X.device)\n",
    "\n",
    "        Y, self.moving_mean, self.moving_var = batch_norm(self.training, X, self.gamma, self.beta, self.moving_mean, self.moving_var, eps=1e-5, momentum=0.9)\n",
    "\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.0037, train acc 0.783, test acc 0.836, time 15.5 sec\n",
      "epoch 2, loss 0.4576, train acc 0.864, test acc 0.849, time 13.3 sec\n",
      "epoch 3, loss 0.3671, train acc 0.879, test acc 0.854, time 13.5 sec\n",
      "epoch 4, loss 0.3282, train acc 0.887, test acc 0.837, time 13.4 sec\n",
      "epoch 5, loss 0.3047, train acc 0.894, test acc 0.846, time 13.2 sec\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1, 6, 5),\n",
    "    BatchNorm(6, num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2, 2),\n",
    "    nn.Conv2d(6, 16, 5),\n",
    "    BatchNorm(16, num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2, 2),\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(16*4*4, 120),\n",
    "    BatchNorm(120, num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(120, 84),\n",
    "    BatchNorm(84, num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(84, 10)\n",
    ")\n",
    "\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(tensor([1.1662, 1.1112, 0.9872, 1.0552, 1.0212, 1.1489], device='cuda:0',\n",
       "        grad_fn=<ViewBackward>),\n",
       " tensor([ 0.2454,  0.0990,  0.0874, -0.5748,  0.4119,  0.4154], device='cuda:0',\n",
       "        grad_fn=<ViewBackward>))"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "net[1].gamma.view((-1,)), net[1].beta.view((-1,))"
   ]
  },
  {
   "source": [
    "# Pytorch自带BatchNorm2d"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.9908, train acc 0.790, test acc 0.843, time 11.7 sec\n",
      "epoch 2, loss 0.4552, train acc 0.864, test acc 0.833, time 10.8 sec\n",
      "epoch 3, loss 0.3676, train acc 0.879, test acc 0.854, time 11.4 sec\n",
      "epoch 4, loss 0.3325, train acc 0.886, test acc 0.867, time 11.0 sec\n",
      "epoch 5, loss 0.3096, train acc 0.892, test acc 0.873, time 11.0 sec\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1, 6, 5),\n",
    "    nn.BatchNorm2d(6),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2, 2),\n",
    "    nn.Conv2d(6, 16, 5),\n",
    "    nn.BatchNorm2d(16),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2, 2),\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(16*4*4, 120),\n",
    "    nn.BatchNorm1d(120),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(120, 84),\n",
    "    nn.BatchNorm1d(84),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(84, 10)\n",
    ")\n",
    "\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}